
GRPO Trainer #1020

Open
michaelbenayoun wants to merge 89 commits into main from grpo

Conversation

michaelbenayoun (Member) commented Nov 4, 2025

What does this PR do?

This PR adds partial support for GRPO.

It was broken down into smaller PRs:

It adds the NeuronGRPOTrainer with a set of optimizations and modifications for the Torch XLA backend used to run on Trainium instances. There are still core missing features:

  • Integration with vLLM: we use a custom CPU vLLM hack for now. The plan is to handle the vLLM part in a separate PR.
  • Weight synchronization between the NeuronGRPOTrainer and vLLM.
  • Tensor parallelism.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-actions bot removed the Stale label Jan 31, 2026
Comment on lines +59 to +63
if not self.experimental:
    raise ValueError(
        "NeuronGRPOTrainer is experimental and not production-ready. To proceed, set `experimental=True` in "
        "your NeuronGRPOConfig. This flag exists to ensure users are aware of the current state of the implementation."
    )
michaelbenayoun (Member Author):

For now we disable access to the NeuronGRPOTrainer.
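For reference, opting in would look roughly like this; a minimal sketch assuming NeuronGRPOConfig accepts standard training arguments such as output_dir alongside the experimental flag:

from optimum.neuron import NeuronGRPOConfig

config = NeuronGRPOConfig(
    output_dir="grpo_output",  # standard training-arguments field
    experimental=True,  # required opt-in while the trainer is experimental
)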

michaelbenayoun marked this pull request as ready for review February 4, 2026 17:16
dacorvo (Collaborator) previously requested changes Feb 5, 2026

The change in the CI workflow aside, this looks good to me, although I did not go into the details of the trainer algorithm.
For the next step, you will need to add a load_state_dict method to NxDPretrainedModel, since the existing load_weights method reads weights from a path. You will also need to provide a modified checkpoint_loader_fn that loads weights from a state_dict directly, as it currently expects to load the state_dict from a path.
Then, in vLLM, you will need to call the load_state_dict method when required.
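
A minimal sketch of the suggested split, assuming NxDPretrainedModel wraps a regular torch.nn.Module; the class and method bodies here are illustrative, and the real load_weights / checkpoint_loader_fn signatures in optimum-neuron differ:

import torch
import torch.nn as nn


class NxDPretrainedModelSketch:
    """Illustrative stand-in for NxDPretrainedModel; not the real class."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(self, path: str) -> None:
        # Existing behavior: read the state_dict from a path, then delegate.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: dict) -> None:
        # Proposed addition: accept an in-memory state_dict directly, so vLLM
        # can push updated weights without a filesystem round-trip.
        self.module.load_state_dict(state_dict, strict=True)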

run: |
  source aws_neuron_venv_pytorch/bin/activate
-  python -m pip install .[neuronx,tests]
+  python -m pip install .[neuronx,tests,training]
Collaborator:

You should not install the training requirements for all workflows:

  • it can create conflicts
  • it will hide any import errors

Copilot AI (Contributor) left a comment

Pull request overview

This PR adds partial support for GRPO (Group Relative Policy Optimization) training on Neuron (Trainium) devices through the new NeuronGRPOTrainer class. The implementation includes XLA-specific optimizations and modifications to work with the Torch XLA backend, though several core features remain unimplemented (vLLM integration, weight synchronization, tensor parallelism).

Changes:

  • Adds NeuronGRPOTrainer with XLA-optimized implementations for generation, scoring, and loss computation
  • Introduces NeuronGRPOConfig for configuration with experimental flag requirement
  • Implements XLA-friendly utility functions (padding, entropy, statistical operations) in trl_utils.py (see the padding sketch after this list)
  • Adds custom vLLM client implementations with CPU communicator and mock client for testing
  • Updates NeuronTrainer to support _prepare_inputs hook and replaces xm.mark_step() with torch_xla.sync()
  • Modifies LoRA transformation utilities to handle missing weights more gracefully
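
On XLA, dynamic shapes trigger graph recompilation, so utilities like the padding helpers above typically pad to fixed lengths. A minimal sketch of that idea with illustrative names, not the actual trl_utils.py API:

import torch


def pad_to_length(ids: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad (or truncate) a 1D token tensor to a fixed length.

    Fixed output shapes keep the compiled XLA graph stable across steps,
    avoiding the recompilation that dynamic sequence lengths would trigger.
    """
    if ids.size(0) >= length:
        return ids[:length]
    padding = ids.new_full((length - ids.size(0),), pad_value)
    return torch.cat([ids, padding])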

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 19 comments.

Summary per file:

  • optimum/neuron/trainers/grpo_trainer.py: core GRPO trainer implementation with XLA optimizations (1414 lines, new file)
  • optimum/neuron/trainers/grpo_config.py: configuration class with validation and experimental flag (118 lines, new file)
  • optimum/neuron/trainers/trl_utils.py: XLA-optimized utility functions for padding, statistics, and sampling (270 lines)
  • optimum/neuron/trainers/extras/vllm_client.py: custom vLLM clients for Neuron with CPU communicator and mock implementation (213 lines, new file)
  • optimum/neuron/trainers/transformers.py: updates to NeuronTrainer for the _prepare_inputs hook and the torch_xla.sync() migration
  • optimum/neuron/trainers/utils.py: adds the move_inputs_to_device utility and updates XLAPrefetchIterator
  • optimum/neuron/models/training/transformations_utils.py: converts LoRA weight errors to silent skips for flexibility
  • optimum/neuron/trainers/metrics/collector.py: refactors get_metric_unit for cleaner logic
  • optimum/neuron/utils/__init__.py: exports the is_vllm_available function
  • optimum/neuron/__init__.py: exports NeuronGRPOTrainer and NeuronGRPOConfig
  • .github/actions/install_optimum_neuron/action.yml: adds training extras to the CI installation


Comment on lines +90 to +112
raise Exception(f"Request failed: {response.status_code}, {response.text}")

world_size = vllm_world_size + 1 # add the client to the world
self.rank = vllm_world_size # the client's rank is the last process

# Initialize weight update group
url = f"{self.base_url}/init_communicator/"

# Use dummy UUID for CPU/Neuron environments
client_device_uuid = "42"

# In the server side, the host is set to 0.0.0.0
response = self.session.post(
url,
json={
"host": "0.0.0.0",
"port": self.group_port,
"world_size": world_size,
"client_device_uuid": client_device_uuid,
},
)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
Copilot AI commented Feb 5, 2026:
Using a bare Exception type is not recommended. Use a more specific exception type like RuntimeError or create a custom exception class for better error handling and debugging. This applies to both lines 90 and 112.

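For illustration, a minimal sketch of that suggestion; VLLMRequestError and check_response are hypothetical names, not code from the PR:

class VLLMRequestError(RuntimeError):
    """Raised when an HTTP request to the vLLM server fails."""


def check_response(response) -> None:
    # Raise a specific exception type instead of a bare Exception.
    if response.status_code != 200:
        raise VLLMRequestError(f"Request failed: {response.status_code}, {response.text}")
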
Comment on lines +134 to +143
def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None):
    self.tokenizer = tokenizer
    self.max_completion_length = max_completion_length
    self.min_completion_length = min(min_completion_length, max_completion_length)
    self.random = random.Random(seed)

    logger.warning(
        "Using MockVLLMClient for neuron_parallel_compile or testing. "
        "This generates echo completions and should only be used for compilation/testing."
    )
Copilot AI commented Feb 5, 2026:
MockVLLMClient inherits from VLLMClient but doesn't call super().__init__(). This means parent class initialization is skipped, which could cause issues if the parent class (TRLVLLMClient via VLLMClient) expects certain attributes to be initialized. The parent class likely sets up self.session, self.base_url, self.host, self.group_port and other attributes that may be accessed. Consider either calling super().__init__() with appropriate parameters or inheriting directly from object if the parent's initialization is not needed.

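A minimal sketch of the second option raised above (making the mock standalone rather than calling super().__init__()), assuming the mock never needs the parent's HTTP or session state:

import random


class MockVLLMClient:
    """Standalone mock: inheriting from object means no parent setup is skipped."""

    def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None):
        self.tokenizer = tokenizer
        self.max_completion_length = max_completion_length
        self.min_completion_length = min(min_completion_length, max_completion_length)
        self.random = random.Random(seed)
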
    Compute the minimum value of a tensor, ignoring NaNs.
    """
    mask = torch.isnan(tensor)
    filled = torch.where(mask, torch.tensor(float("inf"), device=tensor.device), tensor)
Copilot AI commented Feb 5, 2026:
Creating a new tensor with torch.tensor(float("inf"), device=tensor.device) for each call can cause XLA graph fragmentation. Consider pre-creating constant tensors during initialization (similar to _one_float, _inf_float in the trainer) and reusing them, or use filled.new_full((1,), float("inf"))[0] which reuses the existing tensor's properties.

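A minimal sketch of a variant of the suggestion above, using new_full to reuse the input's dtype and device; this is illustrative, not the PR's actual implementation:

import torch


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """Minimum over a tensor, ignoring NaNs, without fresh constant allocations."""
    mask = torch.isnan(tensor)
    # new_full reuses the input's dtype and device, so no new constant tensor
    # is materialized on every call (which could fragment the XLA graph).
    inf = tensor.new_full((), float("inf"))
    filled = torch.where(mask, inf, tensor)
    return filled.min()
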
    num_items_in_batch,
    sampling_per_token_logps_list,
    forward_kwargs,
) = self._generate(prompts, images)
Copilot AI commented Feb 5, 2026:
The method _generate is called but not defined in this class. This will cause an AttributeError at runtime. The method should either be defined in this class or inherited from the parent GRPOTrainer class via the _GRPOTrainer intermediate class. Based on the usage, it should return a tuple of (prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, forward_kwargs).


# Gradient accumulation requires scaled loss
self.model_accepts_loss_kwargs = False

Copilot AI commented Feb 5, 2026:
The _tag_names attribute is accessed but never defined in this class. This will cause an AttributeError at runtime if the model has the add_model_tags method. This attribute should be defined during initialization or inherited from a parent class.

Suggested change
# Ensure _tag_names exists before being passed to model.add_model_tags.
if not hasattr(self, "_tag_names"):
    self._tag_names = set()
Comment on lines +176 to +196
prompt_ids.append(prompt_tokens)

# Generate n completions per prompt
for _ in range(n):
    # Random completion length within bounds
    max_len = min(max_tokens, self.max_completion_length)
    completion_length = self.random.randint(self.min_completion_length, max_len)

    # Echo mode: cycle through prompt tokens
    if len(prompt_tokens) > 0:
        completion = [prompt_tokens[i % len(prompt_tokens)] for i in range(completion_length)]
    else:
        # Fallback if prompt is empty
        completion = [fallback_token_id] * completion_length

    completion_ids.append(completion)

    # Logprobs: simulate higher confidence for echoed tokens
    completion_logprobs = [-self.random.uniform(0.5, 2.0) for _ in range(completion_length)]
    logprobs.append(completion_logprobs)
Copilot AI commented Feb 5, 2026:
The MockVLLMClient generates n completions for each prompt, but only appends each prompt to prompt_ids once (line 176). This means prompt_ids will have length equal to the number of prompts, while completion_ids will have length equal to num_prompts * n. This mismatch in list lengths could cause issues if the caller expects both lists to have the same length. Based on the usage in _generate_single_turn, it appears the expected behavior is for prompt_ids to be repeated for each completion.

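A toy sketch of the alignment the comment describes, assuming the caller zips prompt_ids with completion_ids one-to-one:

prompts = [[1, 2, 3], [4, 5]]  # toy token id lists
n = 2  # completions per prompt

prompt_ids = []
for prompt_tokens in prompts:
    # Repeat the prompt once per completion so both lists stay the same length.
    prompt_ids.extend([prompt_tokens] * n)

assert len(prompt_ids) == len(prompts) * n
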
# Send weights to vLLM server (only main process for server mode)
for name, weight in original_weights.items():
    # Clean up parameter name for vLLM
    name = self._fix_param_name_to_vllm(name)
Copilot AI commented Feb 5, 2026:
The method _fix_param_name_to_vllm is called but not defined in this class or any visible parent class. This will cause an AttributeError when executing this code path. The method should be implemented or inherited from a parent class. Based on the context, it appears this method should clean up parameter names for vLLM compatibility.

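For context, helpers like this typically strip wrapper prefixes so trainer-side parameter names match the vLLM model's; a hedged sketch with illustrative prefixes, not the PR's or TRL's exact implementation:

def fix_param_name_to_vllm(name: str) -> str:
    """Strip wrapper prefixes so a trainer-side parameter name matches vLLM's."""
    # The prefixes below are illustrative (e.g. from activation checkpointing
    # or PEFT wrappers); the real helper may handle a different set.
    for prefix in ("_checkpoint_wrapped_module.", "base_model.model."):
        name = name.replace(prefix, "")
    return name
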
Comment on lines 703 to +704
if to_concat_and_duplicate_name is None or to_unfuse_name is None:
-    raise ValueError(
-        f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
-    )
+    continue
Copilot AI commented Feb 5, 2026:
Similar to the previous issue, this converts a hard error into a silent skip. This could hide configuration problems. Consider logging when weights are not found to aid debugging.

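A minimal sketch of the logged-skip pattern the comment asks for, with hypothetical names; the real logic lives inline in transformations_utils.py:

import logging

logger = logging.getLogger(__name__)


def filter_missing_lora_weights(lookups: dict) -> dict:
    """Drop modules whose LoRA weights were not found, logging each skip."""
    kept = {}
    for module_name, (concat_name, unfuse_name) in lookups.items():
        if concat_name is None or unfuse_name is None:
            # Log instead of skipping silently so configuration issues surface.
            logger.warning("Could not find LoRA weights for %s; skipping.", module_name)
            continue
        kept[module_name] = (concat_name, unfuse_name)
    return kept
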
Comment on lines +690 to +695
# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
# if self.vllm_mode == "server" and self.accelerator.is_main_process:
#     self.vllm_client.update_named_param(name, weight)
# elif self.vllm_mode == "colocate":
#     llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
#     llm_model.load_weights([(name, weight)])
Copilot AI commented Feb 5, 2026:
This comment appears to contain commented-out code.

Suggested change
-# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
-# if self.vllm_mode == "server" and self.accelerator.is_main_process:
-#     self.vllm_client.update_named_param(name, weight)
-# elif self.vllm_mode == "colocate":
-#     llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-#     llm_model.load_weights([(name, weight)])
+# TODO: Support updating vLLM weights for NeuronPeftModel in server and colocate modes.
+# This will be implemented in a future PR as part of the vLLM integration work.
Comment on lines +688 to +690
name = self._fix_param_name_to_vllm(name)

# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
Copilot AI commented Feb 5, 2026:
Variable `name` is assigned but never used.

Suggested change
-name = self._fix_param_name_to_vllm(name)
-# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
+# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
+# name = self._fix_param_name_to_vllm(name)
dacorvo (Collaborator) left a comment:
No more blockers, but I will review in more detail tomorrow. Copilot detected some issues that may be worth considering.

dacorvo dismissed their stale review February 5, 2026 18:38

Blocker addressed

